syntax = "proto3";

package tensorflow.tpu;

import "google/protobuf/wrappers.proto";
import "tensorflow/compiler/xla/service/hlo.proto";

// A pair of optional bounds used for clipping values. Each bound is a
// google.protobuf.FloatValue wrapper message so that "not set" can be told
// apart from an explicit 0.0; an unset bound leaves that side unbounded.
message ClippingLimits {
  // Lower bound of the clipping interval.
  google.protobuf.FloatValue lower = 1;  // -inf if not set
  // Upper bound of the clipping interval.
  google.protobuf.FloatValue upper = 2;  // +inf if not set
}

// Configuration for simulated quantization; simulated quantization is used to
// reduce training/serving skew when the serving variables are quantized. The
// same quantization operations are executed during training to minimize
// differences with serving.
//
// Simulated quantization inserts the following operations on the forward pass
// after gathering the embedding vector from HBM. The backward pass operations
// are unchanged.
//
//   clipped_val = clip(input, clipping_limits)
//   quantum = clipping_limits.range() / (num_buckets - 1)
//   quantized_val =
//       floor((clipped_val - clipping_limits.lower()) / quantum + .5)
//   return quantized_val * quantum + clipping_limits.lower()
//
// NOTE(review): the formula divides by (num_buckets - 1) and reads both ends
// of clipping_limits, so presumably num_buckets must be at least 2 and both
// limits must be set whenever `enabled` is true -- confirm against the
// implementation.
message SimulatedQuantization {
  // Whether simulated quantization is enabled.
  bool enabled = 1;

  // Minimum and maximum values of the range used for quantization; these are
  // the clipping_limits referenced in the formula above.
  ClippingLimits clipping_limits = 2;

  // Number of possible quantized values (num_buckets in the formula above).
  int32 num_buckets = 3;
}

// Dynamic learning rate specification in the TPUEmbeddingConfiguration. The
// actual learning rates are provided as a scalar input list to the
// SendTPUEmbeddingGradients Op indexed by their tag specified through the
// following proto.
message DynamicLearningRate {
  // For tables where learning rates are dynamically computed and communicated
  // to the TPU embedding program, a tag must be specified for the learning
  // rate.
  //
  // The tag must be a non-negative integer. The total number of unique tags
  // must be less than or equal to the number of tables in the TPU embedding
  // configuration (a table does not specify any tag if it uses a constant
  // learning rate, and specifies exactly one tag if it uses dynamic learning
  // rates).
  //
  // All tags in the range [0, number_of_unique_tags) must be present in the TPU
  // embedding configuration, i.e. a tag cannot be skipped if a different tag
  // numerically greater than it is used in the configuration.
  //
  // If multiple tables specify the same tag, they *MUST* have
  // the same dynamic learning rate, for example, their dynamic learning rate
  // could be computed by the same TensorFlow sub-graph. The partitioning of the
  // embedding layer would be more optimal if the number_of_unique_tags is as
  // *LOW* as possible, i.e., if many tables share the same tag.
  //
  // The learning_rate input of the SendTPUEmbeddingGradients op is used to
  // communicate dynamic learning rates to the TPU embedding program.
  // The learning_rate input is a list of scalars where the size of the list is
  // equal to the number of unique tags. The learning rate associated with a
  // particular tag is specified by populating the list element whose index
  // equals that tag.
  int32 tag = 1;
}

// Source of learning rate to use. The two alternatives live in a oneof, so at
// most one of them is set at a time.
message LearningRate {
  oneof learning_rate {
    // A fixed learning rate value.
    float constant = 1;
    // A learning rate supplied at run time and identified by a tag; see
    // DynamicLearningRate.
    DynamicLearningRate dynamic = 2;
  }
}

// Each optimizer's parameter proto has a link to its documentation and CPU
// implementation (if available) for user reference.

// Parameters for the Adagrad optimizer. The message currently declares no
// fields; it only reserves the name and number of a removed field.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adagrad
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1634
message AdagradParameters {
  // Old initial accumulator parameter. Both its name and its field number are
  // reserved so that neither can be reused with a different meaning.
  reserved "initial_accumulator";
  reserved 1;
}

// This optimizer combines the Adagrad and Momentum update rules.
//
//   accum(new) = beta2 == 1.0 ?
//                accum(old) + grad^2 :
//                beta2 * accum(old) + (1 - beta2) * grad^2
//   accum_with_exponent = (accum(new) + epsilon)^(-1.0 / exponent)
//   mom_accum(new) = momentum * mom_accum(old) + accum_with_exponent
//   update = use_nesterov ?
//            momentum * mom_accum(new) + accum_with_exponent :
//            mom_accum(new)
//   var(new) = var(old) - lr * grad * update
//
// Algorithm described in https://arxiv.org/abs/2002.11803.
message AdagradMomentumParameters {
  // Moving average parameter for the momentum accumulator (`momentum` in the
  // formulas above).
  float momentum = 1;

  // Whether to use the Nesterov variant of momentum; selects which branch of
  // the `update` formula above is used.
  bool use_nesterov = 2;

  // Exponent for the gradient^2 accumulator.
  // NOTE(review): the formula above raises to the power (-1.0 / exponent) and
  // the proto3 default for this field is 0 -- presumably callers must always
  // set it to a non-zero value; confirm.
  float exponent = 3;

  // Moving average parameter for the gradient^2 accumulator. A value of
  // exactly 1.0 selects the plain running sum shown in the formula above.
  float beta2 = 4;

  // Offset added to the Adagrad accumulator before exponentiation.
  float epsilon = 5;
}

// Parameters for the bounded Adagrad optimizer: Adagrad with optional limits
// on the per-step variable update and on the accumulator.
//
// Algorithm in http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
message BoundedAdagradParameters {
  // Whether to use the updated or the old value of the accumulator when
  // computing the effective learning rate. When update_accumulator_first is set
  // to True, the updated value of the accumulator is used.
  bool update_accumulator_first = 1;

  // The max_var_update value to use. Set value to 0 (default) to disable using
  // max_var_update to clip the gradient.
  float max_var_update = 2;

  // The maximum value of the accumulator. Set max_accumulator to 0 (default)
  // to disable using max_accumulator to clip the accumulator.
  float max_accumulator = 3;
}

// Parameters for the stochastic gradient descent optimizer. The message
// declares no fields.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L629
message StochasticGradientDescentParameters {}

// Parameters for the FTRL optimizer.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L2646
//
// The hyperparameters for FTRL are the same as for the Keras implementation,
// with some additions. The "beta" parameter matches the behavior described in
// the second link above; "beta" / (2 * learning rate) should be added to "l2"
// to get equivalent behavior in the other TensorFlow implementations of this
// optimizer. When the multiply_linear_by_lr field is set to true, a modified
// formula is used for FTRL that treats the "linear" accumulator as being
// pre-multiplied by the learning rate (i.e., the accumulator named "linear"
// actually stores "linear * learning_rate"). Other than checkpoint
// compatibility, this is mathematically equivalent for a static learning rate;
// for a dynamic learning rate, it is nearly the same as long as the learning
// rate does not change quickly. The benefit of setting multiply_linear_by_lr to
// true is that the modified formula handles zero and near-zero learning rates
// without producing NaNs, improving flexibility for learning rate ramp-up.
//
// Field numbers below are not in declaration order (beta is 7,
// multiply_linear_by_lr is 6). They are part of the wire format and must not
// be renumbered.
message FtrlParameters {
  // L1 regularization hyperparameter; same meaning as in the Keras
  // implementation linked above.
  float l1 = 1;
  // L2 regularization hyperparameter; same meaning as in the Keras
  // implementation linked above. See the note above on how "beta" relates to
  // "l2" in other TensorFlow implementations.
  float l2 = 2;
  // Learning rate power hyperparameter; same meaning as in the Keras
  // implementation linked above.
  float lr_power = 3;
  // The "beta" parameter as described in the second link above.
  float beta = 7;
  // If true, use the modified formula in which the "linear" accumulator stores
  // "linear * learning_rate" (see the message comment for the trade-offs).
  bool multiply_linear_by_lr = 6;

  // Previously, allow_zero_accumulator parameter changed some internal formulas
  // to allow zero and near-zero accumulator values at the cost of some
  // performance. The current implementation ignores this parameter; zero or
  // near-zero accumulator values are now always supported.
  bool allow_zero_accumulator = 8 [deprecated = true];

  // Old initial accumulator parameters. Their names and field numbers are
  // reserved so that they cannot be reused with a different meaning.
  reserved "initial_accum", "initial_linear";
  reserved 4, 5;
}

// The Adam optimizer does not implement hyper-parameter update due to hardware
// limitations; use the dynamic learning rate feature instead, setting the
// learning rate to: user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam
// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L32
//
// Note that the code by default implements the lazy version of Adam
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer)
// unless the use_non_lazy_adam parameter is set, in which case it implements
// the normal version of Adam that updates all parameters in the embedding
// table, even for entries that are not used in the current minibatch
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
// use_non_lazy_adam is enabled, gradient accumulation is also required to be
// enabled in order to get correct results; a warning will be printed otherwise
// (which may change to an error in the future). If use_sum_inside_sqrt is set,
// the Adam variable update formula will be changed from m / (sqrt(v) + epsilon)
// to m / sqrt(v + epsilon**2); this option improves the performance of TPU
// training and is not expected to harm model quality.
message AdamParameters {
  // beta1 as used in the learning rate formula above; see the linked Adam
  // documentation for its role in the algorithm.
  float beta1 = 3;
  // beta2 as used in the learning rate formula above; see the linked Adam
  // documentation for its role in the algorithm.
  float beta2 = 4;
  // epsilon term of the variable update formula shown above.
  float epsilon = 5;
  // If set, use the normal (non-lazy) version of Adam, which updates all
  // parameters in the embedding table. Requires gradient accumulation to be
  // enabled for correct results (see above).
  bool use_non_lazy_adam = 8;
  // If set, the update uses m / sqrt(v + epsilon**2) instead of
  // m / (sqrt(v) + epsilon).
  bool use_sum_inside_sqrt = 10;

  // Old initial accumulator parameters. Their names and field numbers are
  // reserved so that they cannot be reused with a different meaning.
  reserved "initial_m", "initial_v";
  reserved 6, 7;
}

// Parameters for the Momentum optimizer.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L3068
message MomentumParameters {
  // Momentum hyperparameter; see the documentation linked above.
  float momentum = 1;
  // NOTE(review): presumably selects the Nesterov variant of momentum, as the
  // same-named field in AdagradMomentumParameters does -- confirm.
  bool use_nesterov = 2;

  // Old initial accumulator parameter. Its name and field number are reserved
  // so that they cannot be reused with a different meaning.
  reserved "initial_accum";
  reserved 3;
}

// Parameters for the Lion optimizer.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Lion
//
//   momenta(new) = beta2 * momenta(old) + (1 - beta2) * grad
//   momenta_t = beta1 * momenta(old) + (1 - beta1) * grad
//   var(new) = var(old) - lr * sign(momenta_t)
//
// Algorithm described in https://arxiv.org/abs/2302.06675.
message LionParameters {
  // Mixing coefficient used to form momenta_t, whose sign drives the variable
  // update (see the formulas above).
  float beta1 = 1;
  // Mixing coefficient used to update the stored momenta (see the formulas
  // above).
  float beta2 = 2;
  // NOTE(review): presumably analogous to AdamParameters.use_non_lazy_adam,
  // i.e. selects a non-lazy variant of the optimizer -- confirm.
  bool use_non_lazy_lion = 3;
}

// Parameters for the RMSProp optimizer. The fields are the rho, momentum and
// epsilon hyperparameters; see the documentation linked below for their
// definitions.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4229
message RmsPropParameters {
  float rho = 1;
  float momentum = 2;
  float epsilon = 3;

  // Old initial accumulator parameters. Their names and field numbers are
  // reserved so that they cannot be reused with a different meaning.
  reserved "initial_ms", "initial_mom";
  reserved 4, 5;
}

// Parameters for the centered RMSProp optimizer. The fields are the rho,
// momentum and epsilon hyperparameters; see the documentation linked below for
// their definitions.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4358
message CenteredRmsPropParameters {
  float rho = 1;
  float momentum = 2;
  float epsilon = 3;

  // Old initial accumulator parameters. Their names and field numbers are
  // reserved so that they cannot be reused with a different meaning.
  reserved "initial_ms", "initial_mom", "initial_mg";
  reserved 4, 5, 6;
}

// Parameters for the MDL Adagrad Light optimizer.
//
// Variant of algorithm in http://proceedings.mlr.press/v44/shamir15.pdf
//
// The individual hyperparameters are not described in this file; refer to the
// paper above. Note that, per the weight_decay_factor documentation in
// OptimizationParameters, weight decay is not supported with this optimizer.
message MdlAdagradLightParameters {
  float l2 = 1;
  float lr_power = 2;
  float min_servable_mdl_benefit = 3;
  float mdl_mix_in_margin = 4;
  float mdl_benefit_rampup_coeff = 5;
  float mdl_min_weight = 6;
  float benefit_revisit_scale = 7;
  float max_event_benefit = 8;
  float max_total_benefit = 9;
  float mdl_hard_limit = 10;
  bool hard_limit_min_benefit = 11;
  bool mdl_regularize = 12;

  // Old initial accumulator parameters. Their names and field numbers are
  // reserved so that they cannot be reused with a different meaning.
  reserved "initial_accumulator", "initial_weight", "initial_benefit";
  reserved 13, 14, 15;
}

// Parameters for the Adadelta optimizer. The fields are the rho and epsilon
// hyperparameters; see the documentation linked below for their definitions.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adadelta
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L933
message AdadeltaParameters {
  float rho = 1;
  float epsilon = 2;

  // Old initial accumulator parameters. Their names and field numbers are
  // reserved so that they cannot be reused with a different meaning.
  reserved "initial_accumulator", "initial_update";
  reserved 3, 4;
}

// Parameters for the proximal Adagrad optimizer.
//
// https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/ProximalAdagradOptimizer
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1961
message ProximalAdagradParameters {
  // NOTE(review): presumably the L1 regularization strength described in the
  // documentation linked above -- confirm.
  float l1 = 1;
  // NOTE(review): presumably the L2 regularization strength described in the
  // documentation linked above -- confirm.
  float l2 = 2;

  // Old initial accumulator parameter. Its name and field number are reserved
  // so that they cannot be reused with a different meaning.
  reserved "initial_accumulator";
  reserved 3;
}

// Hyperparameters for the online Yogi optimizer, which follows
// https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
// plus some extensions based on FTRL. By default the code implements the lazy
// version of online Yogi.
//
// Hyper-parameter update is not implemented by this optimizer. Use the dynamic
// learning rate feature instead and, with t being the current timestep, set
// the learning rate to:
//   user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
//
// NOTE(review): the formula references beta1, but this message declares no
// beta1 field -- confirm which value is intended.
message OnlineYogiParameters {
  // L1 regularization parameter; used analogously to the one in FTRL.
  float l1 = 1;

  // L2 regularization parameter; used analogously to the one in FTRL.
  float l2 = 2;

  // \beta_2 as defined in Algorithm 2 of the paper.
  float beta2 = 3;

  // Ids formerly used by the removed tanh activation (6: sign, 7: tanh).
  reserved 6, 7;
}

// Hyperparameters for the proximal Yogi optimizer, which follows
// https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
// plus some extensions based on FTRL. By default the code implements the lazy
// version of proximal Yogi.
//
// Hyper-parameter update is not implemented by this optimizer. Use the dynamic
// learning rate feature instead and, with t being the current timestep, set
// the learning rate to:
//   user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
message ProximalYogiParameters {
  // L1 regularization parameter.
  float l1 = 1;

  // L2 regularization parameter.
  float l2 = 2;

  // Exponential decay rate applied to the 1st moment estimates.
  float beta1 = 3;

  // Exponential decay rate applied to the 2nd moment estimates.
  float beta2 = 4;

  // Constant that trades off adaptivity and noise.
  float epsilon = 5;

  // Ids formerly used by the removed tanh activation (8: sign, 9: tanh).
  reserved 8, 9;
}

// Estimator for the frequency of updates to a lookup table. It maintains an
// array (tf.Variable) D, where each element records the average number of
// global steps between two consecutive batches that hit the corresponding
// bucket. Once an item with bucket id i is sampled, D[i] is updated by:
//   D[i] <- D[i] * (1 - tau) + delta[i] * tau,
//
// where tau is a learning rate between 0 and 1 (exclusive), and
//   delta[i] = current global step - last step i is sampled.
//
// The estimated frequency (sampling rate in a batch) is thus 1 / D[i].
//
// Elements in D are initialized with a large value max_delta. delta[i] will
// also be capped by this value.
//
// The exact sequence of operations used in the optimizer is shown below.
// last_hit_step[i] is a tf.Variable that holds the last global step at which i
// was sampled.
//
//   delta = global_step - last_hit_step[i]
//   clipped_delta = min(delta, params.max_delta)
//   is_outlier = (delta >= params.outlier_threshold * D[i])
//   D[i] <- is_outlier ? clipped_delta
//                      : D[i] * (1 - params.tau) + clipped_delta * params.tau
//   last_hit_step[i] <- global_step
//
// Note that the outlier test uses the unclipped delta, while the value written
// to D[i] always uses clipped_delta.
message FrequencyEstimatorParameters {
  // Learning rate in the open interval (0, 1) that is used to update the
  // array D.
  float tau = 1;

  // Maximum value of delta: difference between the current global step and the
  // last global step at which the row was sampled. Also the value with which
  // the elements of D are initialized.
  float max_delta = 2;

  // Threshold used to determine whether the current update is an outlier: an
  // update is an outlier when delta >= outlier_threshold * D[i].
  float outlier_threshold = 3;

  // The weight exponent used to transform the estimated delta into weights.
  // The transformation function is: (delta / max_delta) ^ (weight_exponent)
  float weight_exponent = 4;
}

// A user-defined optimizer.
// The contained HLO program must take the following arguments in the following
// order:
// 1.  gradients
// 2.  table weights
// 3.  slot variables
// 4.  an optional scalar input that is passed in via the dynamic learning
//     rate mechanism.
//
// It must return/end in a tuple op that contains the following values in the
// following order:
// 1.  new table values
// 2.  new slot variable values
//
// The program must have shape (1,1) with dtype float32 throughout and only use
// HLO ops that operate elementwise (e.g., no reduce, no variables, no control
// flow and no broadcasting outside of the single scalar input).
// The HLO program should be written as if it were a dense update. It will be
// called on each row that needs an update and will be applied elementwise.
message UserDefinedProgramParameters {
  // The HLO module implementing the update; it must satisfy the requirements
  // listed in the message comment.
  xla.HloModuleProto program = 1;
  reserved 2;  // Was padding_values
}

// Optimizer that just sets the variable to the value of the gradient. To be
// correct, this requires either gradient accumulation (to sum the values of a
// computed expression across the samples) or to deduplicate IDs within a single
// host (to assign the value from an arbitrary sample).
//
// The message declares no fields.
message AssignParameters {}

// Status of using gradient accumulation (doing two passes over the input
// gradients: one to accumulate them into a temporary array and another to apply
// them using the actual optimization algorithm). The extra message is to wrap
// the enum for scoping.
message GradientAccumulationStatus {
  // if UNSPECIFIED (default), gradient accumulation is ENABLED.
  enum Status {
    // Default value; treated as ENABLED.
    UNSPECIFIED = 0;
    // Gradient accumulation is used.
    ENABLED = 1;
    // Gradient accumulation is not used.
    DISABLED = 2;
  }
}

// Whether to optimize the packing of low-dimensional embedding tables in HBM
// (high bandwidth memory). TPUs access HBM at 32-byte (8-float) granularity.
// For functional correctness, the TPU software internally pads the embedding
// dimension to a multiple of 8. This can sometimes lead to significant memory
// wastage due to padding. For 1-dimensional, 2-dimensional, and 4-dimensional
// embedding tables, the TPU software can remove this padding by packing
// multiple rows into the same 8-float HBM chunk. For example, 8 rows could be
// packed into the same 8-float chunk for a 1-dimensional embedding table.
//
// There is one important limitation for this HBM packing though. When only a
// subset of rows in an 8-float chunk are accessed on a particular step, the
// adjoining rows in the same chunk are updated with zero gradients on the
// backward pass even if they are not touched. This is an artifact of the
// packing implementation. This operation is NOT functionally correct for
// optimizers where zero gradients change the embeddings/slot-variable values,
// e.g., momentum-based optimizers. Hence, this HBM packing cannot be enabled
// for embedding tables with such optimizers. The TPU software automatically
// recognizes that a zero gradient can modify state and turns off the low
// dimensional embedding packing in that scenario.
//
// For optimizers where a zero gradient is a NoOp, such as SGD, Adagrad, and
// FTRL, this packing optimization can be used. However, there are some
// important considerations:
// * Clipping limits: The initial values for such embeddings should fall within
//   the clipping limits specified in the optimization parameters. Otherwise, a
//   zero gradient will cause the embeddings to be clipped. This changes state
//   and hence, is not a NoOp.
// * FTRL: The embedding vector is computed directly from the values of the
//   accumulator and linear slot variables. Hence, the initial embedding values
//   should match that computed from the initial values of the accumulator and
//   linear slot variables. Note that in nearly all cases, the linear value is
//   initialized to zero; this corresponds to an embedding value of zero.
//
// Performance: The TPU has to perform additional work when low dimensional
// packing is enabled. In certain situations when the vocabulary size is small,
// it may not make sense to turn on this packing since the total memory usage
// due to padding is extremely low. Hence, the TPU software automatically turns
// off the packing optimization in such scenarios.
message LowDimensionalPackingStatus {
  // if UNSPECIFIED (default), the low dimension packing status is DISABLED.
  // This can change in future.
  //
  // if ENABLED, the low dimension packing is enabled only if the following
  // three additional conditions are true:
  //  * The optimizer treats the zero gradient as a NoOp.
  //  * The embedding dimension is 1, 2, or 4.
  //  * The vocabulary size is large enough to avoid performance issues.
  //
  // if DISABLED, the low dimension packing is always disabled.
  enum Status {
    // Default value; currently treated as DISABLED.
    UNSPECIFIED = 0;
    // Packing is used when the three conditions listed above all hold.
    ENABLED = 1;
    // Packing is never used.
    DISABLED = 2;
  }
}

// Configuration proto for hot ID optimization. This is an experimental feature
// that is currently disabled (by default).
message HotIdReplicationConfiguration {
  // Whether to enable or disable hot ID optimization.
  // If set to UNSPECIFIED (default), hot ID optimization is DISABLED.
  // If set to ENABLED, hot ID replication is turned ON.
  // If set to MIGRATION_ONLY, hot ID migration is turned ON.
  enum Status {
    // Default value; hot ID optimization is disabled.
    UNSPECIFIED = 0;
    // Hot ID replication is turned on.
    ENABLED = 1;
    // NOTE(review): presumably turns hot ID optimization off explicitly, with
    // the same effect as UNSPECIFIED today -- confirm.
    DISABLED = 2;
    // Hot ID migration is turned on.
    MIGRATION_ONLY = 3;
  }
  // The selected hot ID optimization mode; see Status.
  Status status = 1;
}

// Optimization settings for the embedding layer parameters: the learning rate
// source, clipping limits, weight decay, simulated quantization, gradient
// accumulation, low-dimensional packing, hot ID replication, and the choice of
// optimization algorithm together with its hyperparameters.
//
// Field numbers are not in declaration order. They are part of the wire format
// and must not be renumbered; use NEXT_ID (at the end of the message) for new
// fields.
message OptimizationParameters {
  // Learning rate used for updating the embedding layer parameters.
  LearningRate learning_rate = 13;
  reserved 1;  // Old learning rate tag.

  // Limits to which to clip the weight values after the backward pass; not
  // present means no limits are applied.
  ClippingLimits clipping_limits = 2;

  // Limits to which to clip the backward pass gradient before using it for
  // updates; not present means no limits are applied.
  ClippingLimits gradient_clipping_limits = 7;

  // Amount of weight decay to apply; see weight_decay_optimizers.py for
  // details. All optimizers except MDL Adagrad Light are supported with this
  // option. Although there is no check, users who want weight decay will also
  // want to ensure that gradient accumulation is enabled so that the decay will
  // happen once per global batch.
  float weight_decay_factor = 16;

  // If true, the weight decay factor is multiplied by the current learning rate
  // before use; this is to match the note in DecoupledWeightDecayExtension in
  // weight_decay_optimizers.py.
  bool multiply_weight_decay_factor_by_learning_rate = 22;

  // Configuration for simulated quantization which is used to reduce
  // training/serving skew when the serving variables are quantized. The same
  // quantization operations are executed during training to minimize
  // differences with serving.
  SimulatedQuantization simulated_quantization = 27;

  // Status of using gradient accumulation (doing two passes over the input
  // gradients: one to accumulate them into a temporary array and another to
  // apply them using the actual optimization algorithm).
  GradientAccumulationStatus.Status gradient_accumulation_status = 17;

  // Status of the low-dimensional embedding packing optimization. This controls
  // whether to optimize the packing of 1-dimensional, 2-dimensional, and
  // 4-dimensional embedding tables in memory.
  LowDimensionalPackingStatus.Status low_dimensional_packing_status = 28;

  // Configuration proto for hot ID replication. This is an experimental
  // feature that is currently disabled (by default).
  HotIdReplicationConfiguration hot_id_replication_configuration = 18;

  // Optimization algorithm parameters; which field is selected determines which
  // algorithm to use. Because this is a oneof, at most one algorithm can be
  // selected at a time.
  oneof parameters {
    AdagradParameters adagrad = 3;
    AdagradMomentumParameters adagrad_momentum = 26;
    BoundedAdagradParameters bounded_adagrad = 19;
    StochasticGradientDescentParameters stochastic_gradient_descent = 4;
    FtrlParameters ftrl = 5;
    AdamParameters adam = 6;
    MomentumParameters momentum = 8;
    LionParameters lion = 29;
    RmsPropParameters rms_prop = 9;
    CenteredRmsPropParameters centered_rms_prop = 10;
    MdlAdagradLightParameters mdl_adagrad_light = 11;
    AdadeltaParameters adadelta = 12;
    ProximalAdagradParameters proximal_adagrad = 14;
    OnlineYogiParameters online_yogi = 20;
    ProximalYogiParameters proximal_yogi = 21;
    FrequencyEstimatorParameters frequency_estimator = 23;
    UserDefinedProgramParameters user_defined_program = 24;
    AssignParameters assign = 25;
  }

  reserved 15;  // Old use_gradient_accumulation.

  // NEXT_ID: 30
}

// Specification of an optimization algorithm's state variables (both the main
// value vector and any extra accumulators, etc.). This proto is only used
// internally by the TPU software and is not exposed directly to the TF model.
message StateVariableSpecification {
  // Parameter name for the state variable.
  string name = 1;

  // A normal state variable that should be saved and restored in checkpoints
  // and used as an input or output to non-debug TensorFlow ops. The message
  // currently declares no fields.
  message UserDefined {
    reserved 1;  // Was padding_initial_value.
  }

  // A state variable that should be filled with a constant and normally hidden
  // from users (used for intermediate gradients being accumulated, for
  // example).
  message FillWithConstant {
    // The constant value the state variable is filled with.
    double initial_value = 1;
  }

  // Usage type of this state variable. Because this is a oneof, a state
  // variable is at most one of the kinds below.
  oneof usage {
    UserDefined user_defined = 2;
    FillWithConstant fill_with_constant = 3;
  }
}
